import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt

def _first_indices_per_class(labels, n_classes=10, per_class=10):
    """Return the row positions of the first ``per_class`` samples of each class.

    Args:
        labels: 1-D sequence of integer class labels, one per sample.
        n_classes: Number of classes; labels ``0 .. n_classes - 1`` are searched.
        per_class: How many sample positions to collect for each class.

    Returns:
        An integer array of shape ``(n_classes, per_class)`` whose row ``c``
        holds, in ascending order, the positions of the first ``per_class``
        samples labelled ``c``.

    Raises:
        ValueError: If some class has fewer than ``per_class`` samples.
    """
    labels = np.asarray(labels)
    indices = np.empty((n_classes, per_class), dtype=int)
    for cls in range(n_classes):
        positions = np.flatnonzero(labels == cls)[:per_class]
        if positions.size < per_class:
            # Fail with a clear message instead of scanning past the end.
            raise ValueError(
                f"class {cls} has only {positions.size} samples; "
                f"need {per_class}"
            )
        indices[cls] = positions
    return indices


# Load MNIST. The first 784 columns are taken as the 28x28 pixel values and
# 'class' as the label column (layout assumed from the slice/reshape below).
mnist = pd.read_csv('./datasets/mnist_784.csv')
# Slice the pixel columns in pandas *before* converting to NumPy, so the label
# column cannot up-cast the pixel matrix to an object dtype imshow rejects.
X = mnist.iloc[:, 0:784].to_numpy()
# Normalise labels to int so comparisons against 0..9 work even if the
# column was parsed as strings.
y = mnist['class'].astype(int)
# digit[i, j] is the row position of the j-th sample labelled i.
digit = _first_indices_per_class(y, n_classes=10, per_class=10)

# 10x10 grid: one row per digit class, one column per example.
fig, axs = plt.subplots(10, 10, figsize=(10, 10))
for i, row_axes in enumerate(axs):
    for j, ax in enumerate(row_axes):
        image = X[digit[i, j]].reshape(28, 28)
        ax.imshow(image, cmap='binary', interpolation='nearest')
        ax.axis("off")

plt.show()
